{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed Data Classification with Quality and Domain Classifiers\n",
    "\n",
    "This notebook demonstrates distributed data classification with two classifiers: a quality classifier, which labels the quality of each document, and a domain classifier, which labels its domain. These classifiers annotate the data, and the annotations help with data blending for foundation model training.\n",
    "\n",
    "The classifiers are accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intelligent batching and RAPIDS to accelerate offline inference on large datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: PYTHONWARNINGS=ignore\n"
     ]
    }
   ],
   "source": [
    "#### Silence Warnings (HuggingFace internal warnings)\n",
    "\n",
    "%env PYTHONWARNINGS=ignore\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dask_cuda import LocalCUDACluster\n",
    "from dask.distributed import Client\n",
    "from nemo_curator import DomainClassifier, QualityClassifier\n",
    "from nemo_curator.datasets import DocumentDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster = LocalCUDACluster(rmm_async=True, rmm_pool_size=\"1GB\")\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the data file paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_file_path = \"/input_data_dir/\"\n",
    "output_file_path = \"output_data_dir/\"\n",
    "domain_model_path = \"domain_model.pth\"\n",
    "quality_model_path = \"quality_model.pth\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create a Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_type = \"DomainClassifier\"  # or \"QualityClassifier\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 16 files\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10.5 s, sys: 5.33 s, total: 15.8 s\n",
      "Wall time: 11.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "input_dataset = DocumentDataset.read_json(\n",
    "    input_file_path, backend=\"cudf\", add_filename=True\n",
    ")\n",
    "\n",
    "if classifier_type == \"DomainClassifier\":\n",
    "    domain_labels = [\n",
    "        \"Adult\",\n",
    "        \"Arts_and_Entertainment\",\n",
    "        \"Autos_and_Vehicles\",\n",
    "        \"Beauty_and_Fitness\",\n",
    "        \"Books_and_Literature\",\n",
    "        \"Business_and_Industrial\",\n",
    "        \"Computers_and_Electronics\",\n",
    "        \"Finance\",\n",
    "        \"Food_and_Drink\",\n",
    "        \"Games\",\n",
    "        \"Health\",\n",
    "        \"Hobbies_and_Leisure\",\n",
    "        \"Home_and_Garden\",\n",
    "        \"Internet_and_Telecom\",\n",
    "        \"Jobs_and_Education\",\n",
    "        \"Law_and_Government\",\n",
    "        \"News\",\n",
    "        \"Online_Communities\",\n",
    "        \"People_and_Society\",\n",
    "        \"Pets_and_Animals\",\n",
    "        \"Real_Estate\",\n",
    "        \"Science\",\n",
    "        \"Sensitive_Subjects\",\n",
    "        \"Shopping\",\n",
    "        \"Sports\",\n",
    "        \"Travel_and_Transportation\",\n",
    "    ]\n",
    "    classifier = DomainClassifier(\n",
    "        model_path=domain_model_path,\n",
    "        labels=domain_labels,\n",
    "        batch_size=1024,\n",
    "    )\n",
    "elif classifier_type == \"QualityClassifier\":\n",
    "    quality_labels = [\"High\", \"Medium\", \"Low\"]\n",
    "    classifier = QualityClassifier(\n",
    "        model_path=quality_model_path,\n",
    "        labels=quality_labels,\n",
    "        batch_size=1024,\n",
    "    )\n",
    "else:\n",
    "    raise ValueError(\"Invalid classifier type\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the Classifier\n",
    "\n",
    "Dask operations are lazy, so the classifier will not run until we call an eager operation such as `to_json`, `compute`, or `persist`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting domain classifier inference\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU: 0, Part: 1: 100%|██████████| 938/938 [00:09<00:00, 101.99it/s] \n",
      "GPU: 0, Part: 3: 100%|██████████| 938/938 [00:10<00:00, 92.36it/s] ]\n",
      "GPU: 0, Part: 0: 100%|██████████| 938/938 [00:10<00:00, 91.25it/s] ]\n",
      "GPU: 0, Part: 5: 100%|██████████| 938/938 [00:10<00:00, 88.82it/s] \n",
      "GPU: 0, Part: 14: 100%|██████████| 937/937 [00:10<00:00, 88.11it/s] \n",
      "GPU: 0, Part: 8: 100%|██████████| 937/937 [00:10<00:00, 85.46it/s] ]\n",
      "GPU: 0, Part: 9: 100%|██████████| 937/937 [00:10<00:00, 86.16it/s] \n",
      "GPU: 0, Part: 4: 100%|██████████| 938/938 [00:10<00:00, 85.65it/s]]\n",
      "GPU: 0, Part: 11: 100%|██████████| 937/937 [00:11<00:00, 83.73it/s] \n",
      "GPU: 0, Part: 6: 100%|██████████| 938/938 [00:11<00:00, 83.62it/s]\n",
      "GPU: 0, Part: 10: 100%|██████████| 937/937 [00:11<00:00, 81.27it/s] \n",
      "GPU: 0, Part: 2: 100%|██████████| 938/938 [00:12<00:00, 72.59it/s]]\n",
      "GPU: 0, Part: 7: 100%|██████████| 937/937 [00:13<00:00, 71.75it/s]\n",
      "GPU: 0, Part: 12: 100%|██████████| 937/937 [00:13<00:00, 69.12it/s]\n",
      "GPU: 0, Part: 15: 100%|██████████| 937/937 [00:13<00:00, 68.47it/s]\n",
      "GPU: 0, Part: 13: 100%|██████████| 937/937 [00:14<00:00, 66.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing to disk complete for 16 partitions\n",
      "CPU times: user 2.34 s, sys: 2.24 s, total: 4.58 s\n",
      "Wall time: 17.2 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "result_dataset = classifier(dataset=input_dataset)\n",
    "result_dataset.to_json(output_file_dir=output_file_path, write_to_filename=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inspect the Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading 16 files\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>adlr_id</th>\n",
       "      <th>domain_pred</th>\n",
       "      <th>filename</th>\n",
       "      <th>id</th>\n",
       "      <th>pred</th>\n",
       "      <th>source_id</th>\n",
       "      <th>split_id</th>\n",
       "      <th>text</th>\n",
       "      <th>url</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>cc-2022-40-0431053204</td>\n",
       "      <td>Online_Communities</td>\n",
       "      <td>00.jsonl</td>\n",
       "      <td>a8083fe4-525d-4888-8513-b91f43bd8ee1</td>\n",
       "      <td>Online_Communities</td>\n",
       "      <td>crawl-data-CC-MAIN-2022-40-segments-1664030336...</td>\n",
       "      <td>lambada-0003225258-0000</td>\n",
       "      <td>Having been a community leader—and member—for ...</td>\n",
       "      <td>https://lisalarter.com/7-tips-for-building-ste...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>cc-2022-40-0510168267</td>\n",
       "      <td>Finance</td>\n",
       "      <td>00.jsonl</td>\n",
       "      <td>559febdc-cb7f-4217-897a-c8dac325123b</td>\n",
       "      <td>Finance</td>\n",
       "      <td>crawl-data-CC-MAIN-2022-40-segments-1664030337...</td>\n",
       "      <td>lambada-0003918122-0000</td>\n",
       "      <td>Zelle is a way of sending money to almost anyo...</td>\n",
       "      <td>https://oregonmassageandwellnessclinic.com/app...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 adlr_id         domain_pred  filename  \\\n",
       "0  cc-2022-40-0431053204  Online_Communities  00.jsonl   \n",
       "1  cc-2022-40-0510168267             Finance  00.jsonl   \n",
       "\n",
       "                                     id                pred  \\\n",
       "0  a8083fe4-525d-4888-8513-b91f43bd8ee1  Online_Communities   \n",
       "1  559febdc-cb7f-4217-897a-c8dac325123b             Finance   \n",
       "\n",
       "                                           source_id                 split_id  \\\n",
       "0  crawl-data-CC-MAIN-2022-40-segments-1664030336...  lambada-0003225258-0000   \n",
       "1  crawl-data-CC-MAIN-2022-40-segments-1664030337...  lambada-0003918122-0000   \n",
       "\n",
       "                                                text  \\\n",
       "0  Having been a community leader—and member—for ...   \n",
       "1  Zelle is a way of sending money to almost anyo...   \n",
       "\n",
       "                                                 url  \n",
       "0  https://lisalarter.com/7-tips-for-building-ste...  \n",
       "1  https://oregonmassageandwellnessclinic.com/app...  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_dataset = DocumentDataset.read_json(output_file_path, backend=\"cudf\", add_filename=True)\n",
    "output_dataset.df.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Clean up the output files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf $output_file_path"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NeMo-Curator-env-2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
